{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A first CausalNex tutorial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial walks you through an example workflow that uses CausalNex to estimate whether a student will pass or fail an exam, based on influences such as school support and family relationships. We will use the [Student Performance Data Set](https://archive.ics.uci.edu/ml/datasets/Student+Performance) published in the [UCI Machine Learning Repository](http://archive.ics.uci.edu/ml).\n",
    "\n",
    "To work through this tutorial, first create a new Python 3 notebook. Then download [student.zip](https://archive.ics.uci.edu/ml/machine-learning-databases/00320/student.zip) and extract `student-por.csv` from it into the same directory as the notebook. Finally, copy and paste the code cells from this tutorial into your notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Structure Learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Defining the structure of a Bayesian Network (BN) model can be done based on machine learning, domain knowledge, or a combination of both, where experts and algorithms contribute as equal partners.\n",
    "\n",
    "Regardless of the approach, it is important to validate the structure by evaluating the BN - this will be covered later in the tutorial. In this section, we will focus on how to define a structure."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Structure from Domain Knowledge"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can manually define a structure model by specifying the relationships between different features.\n",
    "\n",
    "First, we must create an empty structure model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "from causalnex.structure import StructureModel\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")  # silence warnings\n",
    "\n",
    "sm = StructureModel()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we can specify the relationships between features. For example, let's assume that experts tell us the following causal relationships are known (`G1` is the grade in semester 1):\n",
    "\n",
    "* `health` -> `absences`\n",
    "* `health` -> `G1`\n",
    "\n",
    "We can add these relationships to our structure model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sm.add_edges_from([\n",
    "    ('health', 'absences'),\n",
    "    ('health', 'G1')\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualising the Structure"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can examine a StructureModel by looking at the output of `sm.edges`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OutEdgeView([('health', 'absences'), ('health', 'G1')])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sm.edges"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, it is often more intuitive to visualise the structure. CausalNex provides a plotting module that allows us to do this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"600px\"\n",
       "            src=\"supporting_files/01_simple_plot.html\"\n",
       "            html=\"\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "            \n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fab9ab46da0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from causalnex.plots import plot_structure, NODE_STYLE, EDGE_STYLE\n",
    "\n",
    "viz = plot_structure(\n",
    "    sm, \n",
    "    all_node_attributes=NODE_STYLE.WEAK,\n",
    "    all_edge_attributes=EDGE_STYLE.WEAK,\n",
    ")\n",
    "viz.show(\"supporting_files/01_simple_plot.html\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Learning the Structure"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the number of variables grows, or when domain knowledge does not exist, it can be tedious to define a structure manually. We can use CausalNex to learn the structure model from data. The structure learning algorithm we are going to use here is the [NOTEARS](https://arxiv.org/abs/1803.01422) algorithm.\n",
    "\n",
    "When learning structure, we can use the entire dataset. Since structure should be considered a joint effort between machine learning and domain experts, it is not always necessary to use a train / test split.\n",
    "\n",
    "Before we begin, however, we have to pre-process the data so that the NOTEARS algorithm can be used."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing the Data for Structure Learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>school</th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>address</th>\n",
       "      <th>famsize</th>\n",
       "      <th>Pstatus</th>\n",
       "      <th>Medu</th>\n",
       "      <th>Fedu</th>\n",
       "      <th>Mjob</th>\n",
       "      <th>Fjob</th>\n",
       "      <th>...</th>\n",
       "      <th>famrel</th>\n",
       "      <th>freetime</th>\n",
       "      <th>goout</th>\n",
       "      <th>Dalc</th>\n",
       "      <th>Walc</th>\n",
       "      <th>health</th>\n",
       "      <th>absences</th>\n",
       "      <th>G1</th>\n",
       "      <th>G2</th>\n",
       "      <th>G3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>GP</td>\n",
       "      <td>F</td>\n",
       "      <td>18</td>\n",
       "      <td>U</td>\n",
       "      <td>GT3</td>\n",
       "      <td>A</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>at_home</td>\n",
       "      <td>teacher</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>11</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>GP</td>\n",
       "      <td>F</td>\n",
       "      <td>17</td>\n",
       "      <td>U</td>\n",
       "      <td>GT3</td>\n",
       "      <td>T</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>at_home</td>\n",
       "      <td>other</td>\n",
       "      <td>...</td>\n",
       "      <td>5</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>9</td>\n",
       "      <td>11</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>GP</td>\n",
       "      <td>F</td>\n",
       "      <td>15</td>\n",
       "      <td>U</td>\n",
       "      <td>LE3</td>\n",
       "      <td>T</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>at_home</td>\n",
       "      <td>other</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>6</td>\n",
       "      <td>12</td>\n",
       "      <td>13</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>GP</td>\n",
       "      <td>F</td>\n",
       "      <td>15</td>\n",
       "      <td>U</td>\n",
       "      <td>GT3</td>\n",
       "      <td>T</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>health</td>\n",
       "      <td>services</td>\n",
       "      <td>...</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>14</td>\n",
       "      <td>14</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>GP</td>\n",
       "      <td>F</td>\n",
       "      <td>16</td>\n",
       "      <td>U</td>\n",
       "      <td>GT3</td>\n",
       "      <td>T</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>other</td>\n",
       "      <td>other</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>11</td>\n",
       "      <td>13</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 33 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "  school sex  age address famsize Pstatus  Medu  Fedu     Mjob      Fjob  ...  \\\n",
       "0     GP   F   18       U     GT3       A     4     4  at_home   teacher  ...   \n",
       "1     GP   F   17       U     GT3       T     1     1  at_home     other  ...   \n",
       "2     GP   F   15       U     LE3       T     1     1  at_home     other  ...   \n",
       "3     GP   F   15       U     GT3       T     4     2   health  services  ...   \n",
       "4     GP   F   16       U     GT3       T     3     3    other     other  ...   \n",
       "\n",
       "  famrel freetime  goout  Dalc  Walc health absences  G1  G2  G3  \n",
       "0      4        3      4     1     1      3        4   0  11  11  \n",
       "1      5        3      3     1     1      3        2   9  11  11  \n",
       "2      4        3      2     2     3      3        6  12  13  12  \n",
       "3      3        2      2     1     1      5        0  14  14  14  \n",
       "4      4        3      2     1     2      5        0  11  13  13  \n",
       "\n",
       "[5 rows x 33 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "data = pd.read_csv('student-por.csv', delimiter=';')\n",
    "data.head(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looking at the data, we can see that the features consist of numeric and non-numeric columns. We can drop features that we do not want to include in our model, such as sensitive attributes like `sex`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>address</th>\n",
       "      <th>famsize</th>\n",
       "      <th>Pstatus</th>\n",
       "      <th>Medu</th>\n",
       "      <th>Fedu</th>\n",
       "      <th>traveltime</th>\n",
       "      <th>studytime</th>\n",
       "      <th>failures</th>\n",
       "      <th>schoolsup</th>\n",
       "      <th>famsup</th>\n",
       "      <th>...</th>\n",
       "      <th>famrel</th>\n",
       "      <th>freetime</th>\n",
       "      <th>goout</th>\n",
       "      <th>Dalc</th>\n",
       "      <th>Walc</th>\n",
       "      <th>health</th>\n",
       "      <th>absences</th>\n",
       "      <th>G1</th>\n",
       "      <th>G2</th>\n",
       "      <th>G3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>U</td>\n",
       "      <td>GT3</td>\n",
       "      <td>A</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>yes</td>\n",
       "      <td>no</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>11</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>U</td>\n",
       "      <td>GT3</td>\n",
       "      <td>T</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>no</td>\n",
       "      <td>yes</td>\n",
       "      <td>...</td>\n",
       "      <td>5</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>9</td>\n",
       "      <td>11</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>U</td>\n",
       "      <td>LE3</td>\n",
       "      <td>T</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>yes</td>\n",
       "      <td>no</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>6</td>\n",
       "      <td>12</td>\n",
       "      <td>13</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>U</td>\n",
       "      <td>GT3</td>\n",
       "      <td>T</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>no</td>\n",
       "      <td>yes</td>\n",
       "      <td>...</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>14</td>\n",
       "      <td>14</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>U</td>\n",
       "      <td>GT3</td>\n",
       "      <td>T</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>no</td>\n",
       "      <td>yes</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>11</td>\n",
       "      <td>13</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 26 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "  address famsize Pstatus  Medu  Fedu  traveltime  studytime  failures  \\\n",
       "0       U     GT3       A     4     4           2          2         0   \n",
       "1       U     GT3       T     1     1           1          2         0   \n",
       "2       U     LE3       T     1     1           1          2         0   \n",
       "3       U     GT3       T     4     2           1          3         0   \n",
       "4       U     GT3       T     3     3           1          2         0   \n",
       "\n",
       "  schoolsup famsup  ... famrel freetime goout Dalc Walc health  absences  G1  \\\n",
       "0       yes     no  ...      4        3     4    1    1      3         4   0   \n",
       "1        no    yes  ...      5        3     3    1    1      3         2   9   \n",
       "2       yes     no  ...      4        3     2    2    3      3         6  12   \n",
       "3        no    yes  ...      3        2     2    1    1      5         0  14   \n",
       "4        no    yes  ...      4        3     2    1    2      5         0  11   \n",
       "\n",
       "   G2  G3  \n",
       "0  11  11  \n",
       "1  11  11  \n",
       "2  13  12  \n",
       "3  14  14  \n",
       "4  13  13  \n",
       "\n",
       "[5 rows x 26 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "drop_col = ['school','sex','age','Mjob', 'Fjob','reason','guardian']\n",
    "data = data.drop(columns=drop_col)\n",
    "data.head(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we want to make our data numeric, since this is what the NOTEARS algorithm expects. We can do this by label encoding the non-numeric variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['address', 'famsize', 'Pstatus', 'schoolsup', 'famsup', 'paid', 'activities', 'nursery', 'higher', 'internet', 'romantic']\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "struct_data = data.copy()\n",
    "non_numeric_columns = list(struct_data.select_dtypes(exclude=[np.number]).columns)\n",
    "\n",
    "print(non_numeric_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>address</th>\n",
       "      <th>famsize</th>\n",
       "      <th>Pstatus</th>\n",
       "      <th>Medu</th>\n",
       "      <th>Fedu</th>\n",
       "      <th>traveltime</th>\n",
       "      <th>studytime</th>\n",
       "      <th>failures</th>\n",
       "      <th>schoolsup</th>\n",
       "      <th>famsup</th>\n",
       "      <th>...</th>\n",
       "      <th>famrel</th>\n",
       "      <th>freetime</th>\n",
       "      <th>goout</th>\n",
       "      <th>Dalc</th>\n",
       "      <th>Walc</th>\n",
       "      <th>health</th>\n",
       "      <th>absences</th>\n",
       "      <th>G1</th>\n",
       "      <th>G2</th>\n",
       "      <th>G3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>11</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>5</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>9</td>\n",
       "      <td>11</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>6</td>\n",
       "      <td>12</td>\n",
       "      <td>13</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>14</td>\n",
       "      <td>14</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>11</td>\n",
       "      <td>13</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 26 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   address  famsize  Pstatus  Medu  Fedu  traveltime  studytime  failures  \\\n",
       "0        1        0        0     4     4           2          2         0   \n",
       "1        1        0        1     1     1           1          2         0   \n",
       "2        1        1        1     1     1           1          2         0   \n",
       "3        1        0        1     4     2           1          3         0   \n",
       "4        1        0        1     3     3           1          2         0   \n",
       "\n",
       "   schoolsup  famsup  ...  famrel  freetime  goout  Dalc  Walc  health  \\\n",
       "0          1       0  ...       4         3      4     1     1       3   \n",
       "1          0       1  ...       5         3      3     1     1       3   \n",
       "2          1       0  ...       4         3      2     2     3       3   \n",
       "3          0       1  ...       3         2      2     1     1       5   \n",
       "4          0       1  ...       4         3      2     1     2       5   \n",
       "\n",
       "   absences  G1  G2  G3  \n",
       "0         4   0  11  11  \n",
       "1         2   9  11  11  \n",
       "2         6  12  13  12  \n",
       "3         0  14  14  14  \n",
       "4         0  11  13  13  \n",
       "\n",
       "[5 rows x 26 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "le = LabelEncoder()\n",
    "\n",
    "for col in non_numeric_columns:\n",
    "    struct_data[col] = le.fit_transform(struct_data[col])\n",
    "    \n",
    "struct_data.head(5)"
   ]
  },
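  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an aside, the label encoding used above can be pictured with a minimal pure-Python sketch. It assumes the encoder numbers the distinct labels of a column in sorted order, which is consistent with the encoded table shown above (`GT3` becomes 0, `LE3` becomes 1). The helper `label_encode` and the sample values are hypothetical and only illustrate the idea; they are not part of CausalNex or scikit-learn.\n",
    "\n",
    "```python\n",
    "def label_encode(values):\n",
    "    # Map each distinct label to an integer code, in sorted label order.\n",
    "    classes = sorted(set(values))\n",
    "    code_of = {label: code for code, label in enumerate(classes)}\n",
    "    return [code_of[v] for v in values], classes\n",
    "\n",
    "\n",
    "# Hypothetical sample values in the style of the `famsize` and `schoolsup` columns.\n",
    "famsize_codes, famsize_classes = label_encode(['GT3', 'GT3', 'LE3', 'GT3'])\n",
    "schoolsup_codes, schoolsup_classes = label_encode(['yes', 'no', 'yes', 'no'])\n",
    "\n",
    "print(famsize_classes, famsize_codes)\n",
    "print(schoolsup_classes, schoolsup_codes)\n",
    "```\n",
    "\n",
    "Integer codes of this kind carry an implied ordering. For two-valued columns that is harmless, but for columns with more than two categories the ordering may not be meaningful, so it is worth checking which columns are being encoded."
   ]
  },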
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now apply the NOTEARS algorithm to learn the structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from causalnex.structure.notears import from_pandas\n",
    "sm = from_pandas(struct_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then visualise the learned StructureModel using the `plot_structure` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"600px\"\n",
       "            src=\"supporting_files/01_fully_connected.html\"\n",
       "            html=\"\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "            \n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fab9b3a1480>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "viz = plot_structure(\n",
    "    sm, \n",
    "    all_node_attributes=NODE_STYLE.WEAK,\n",
    "    all_edge_attributes=EDGE_STYLE.WEAK,\n",
    ")\n",
    "\n",
    "viz.toggle_physics(False)\n",
    "viz.show(\"supporting_files/01_fully_connected.html\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The graph is fully connected because we have not yet applied thresholding to remove the weaker edges. Thresholding can be applied either by passing a value for the `w_threshold` parameter of `from_pandas`, or by calling the StructureModel method `remove_edges_below_threshold` after learning."
   ]
  },
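  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of thresholding can be sketched in a few lines of plain Python. The sketch assumes that an edge is dropped when the absolute value of its weight is below the threshold. The edge weights and the helper `threshold_edges` are hypothetical; they are invented for illustration and are not taken from the learned model.\n",
    "\n",
    "```python\n",
    "# Hypothetical edge weights, invented for illustration only.\n",
    "weights = {\n",
    "    ('studytime', 'G1'): 1.2,\n",
    "    ('internet', 'absences'): -0.9,\n",
    "    ('famsup', 'paid'): 0.05,\n",
    "    ('goout', 'Walc'): -0.3,\n",
    "}\n",
    "\n",
    "\n",
    "def threshold_edges(edge_weights, threshold):\n",
    "    # Keep an edge only if the magnitude of its weight reaches the threshold.\n",
    "    return {edge: w for edge, w in edge_weights.items() if abs(w) >= threshold}\n",
    "\n",
    "\n",
    "kept = threshold_edges(weights, 0.8)\n",
    "print(sorted(kept))\n",
    "```\n",
    "\n",
    "A higher threshold gives a sparser graph, so the value is a trade-off between removing spurious edges and keeping weak but genuine ones."
   ]
  },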
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"600px\"\n",
       "            src=\"supporting_files/01_thresholded.html\"\n",
       "            html=\"\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "            \n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fab9b3d5ba0>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sm.remove_edges_below_threshold(0.8)\n",
    "viz = plot_structure(\n",
    "    sm, \n",
    "    all_node_attributes=NODE_STYLE.WEAK,\n",
    "    all_edge_attributes=EDGE_STYLE.WEAK,\n",
    ")\n",
    "viz.show(\"supporting_files/01_thresholded.html\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this structure, we can see that there are some relationships that appear intuitively correct:\n",
    "\n",
    "* `Pstatus` affects `famrel` - if parents live apart, the quality of family relationships may be poorer as a result.\n",
    "* `internet` affects `absences` - having internet access at home may cause students to skip class.\n",
    "* `studytime` affects `G1` - a longer study time should have a positive impact on a student's result.\n",
    "\n",
    "However, some relationships are certainly incorrect:\n",
    "\n",
    "* `higher` affects `Medu` (mother's education) - this relationship does not make sense, because a student's wish to pursue higher education cannot affect their mother's education. The influence could plausibly run the other way round.\n",
    "\n",
    "To avoid these erroneous relationships, we can re-run structure learning with some added constraints:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "sm = from_pandas(struct_data, tabu_edges=[(\"higher\", \"Medu\")], w_threshold=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"600px\"\n",
       "            src=\"supporting_files/01_edge_added.html\"\n",
       "            html=\"\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "            \n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fab9b39dd20>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "viz = plot_structure(\n",
    "    sm, \n",
    "    all_node_attributes=NODE_STYLE.WEAK,\n",
    "    all_edge_attributes=EDGE_STYLE.WEAK,\n",
    ")\n",
    "viz.show(\"supporting_files/01_edge_added.html\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Modifying the Structure"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To correct erroneous relationships, we can incorporate domain knowledge into the model after structure learning. We can modify the structure model by adding and removing edges. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sm.add_edge(\"failures\", \"G1\")\n",
    "sm.remove_edge(\"Pstatus\", \"G1\")\n",
    "sm.remove_edge(\"address\", \"G1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now visualise our updated structure to confirm it looks reasonable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"600px\"\n",
       "            src=\"supporting_files/01_modified_structure.html\"\n",
       "            html=\"\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "            \n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fab9b3a63e0>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "viz = plot_structure(\n",
    "    sm, \n",
    "    all_node_attributes=NODE_STYLE.WEAK,\n",
    "    all_edge_attributes=EDGE_STYLE.WEAK,\n",
    ")\n",
    "viz.show(\"supporting_files/01_modified_structure.html\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see two separate subgraphs in the visualisation: `Dalc->Walc` and a second, much larger subgraph. We can retrieve the largest subgraph by calling the `StructureModel` method `get_largest_subgraph()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"600px\"\n",
       "            src=\"supporting_files/01_largest_subgraph.html\"\n",
       "            html=\"\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "            \n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fab9b3aaf50>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sm = sm.get_largest_subgraph()\n",
    "viz = plot_structure(\n",
    "    sm, \n",
    "    all_node_attributes=NODE_STYLE.WEAK,\n",
    "    all_edge_attributes=EDGE_STYLE.WEAK,\n",
    ")\n",
    "viz.show(\"supporting_files/01_largest_subgraph.html\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exporting the Structure\n",
    "\n",
    "It is worth noting that `StructureModel` extends `networkx.DiGraph` (see the [source](https://github.com/quantumblacklabs/causalnex/blob/develop/causalnex/structure/structuremodel.py)). We can therefore use `networkx`'s read/write methods to import/export our `StructureModel`. For example, we can export to a graphviz `.dot` file using the `write_dot` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "\n",
    "nx.drawing.nx_pydot.write_dot(sm, 'graph.dot')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To export to other formats, please refer to `networkx`'s [documentation](https://networkx.org/documentation/stable/reference/readwrite/index.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fitting the Conditional Distribution of the Bayesian Network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After deciding on how the final structure model should look, we can instantiate a `BayesianNetwork`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from causalnex.network import BayesianNetwork\n",
    "\n",
    "bn = BayesianNetwork(sm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are now ready to move on to learning the conditional probability distribution of different features in the `BayesianNetwork`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing the Discretised Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Bayesian Networks in CausalNex support only discrete distributions. Any continuous features, or features with a large number of categories, should be discretised prior to fitting the Bayesian Network. Models containing variables with many possible values will typically be badly fit, and exhibit poor performance.\n",
    "\n",
    "For example, consider P(G2 | G1), where G1 and G2 have possible values 0 to 20. The discrete conditional probability distribution is therefore specified using 21x21 (441) possible combinations - most of which we will be unlikely to observe.\n",
    "\n",
    "CausalNex provides a few helper methods to make discretisation easier. Let's start by reducing the number of categories in some of the categorical features by combining similar values. We will make numeric features categorical by discretisation, and then give the buckets meaningful labels."
   ]
  },
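  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of why cardinality matters, the cell below counts the entries of a discrete CPD table as the number of node states multiplied by the number of parent-state combinations. The helper `cpd_size` is hypothetical and written only for this illustration; it is not part of CausalNex."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import prod\n",
    "\n",
    "\n",
    "def cpd_size(node_states, parent_states):\n",
    "    # One entry per (node state, parent-state combination).\n",
    "    return node_states * prod(parent_states)\n",
    "\n",
    "\n",
    "# P(G2 | G1) with raw grades 0 to 20: 21 states for each variable.\n",
    "raw_size = cpd_size(21, [21])\n",
    "# The same CPD after both grades are reduced to two buckets.\n",
    "bucketed_size = cpd_size(2, [2])\n",
    "print(raw_size, bucketed_size)"
   ]
  },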
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cardinality of Categorical Features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To reduce the cardinality of categorical features, we can define a map `{old_value: new_value}` and use it to update the feature. For example, in the `studytime` feature, we relabel values greater than 2 as `long-studytime` and the rest as `short-studytime` (a value of 2 means 2 to 5 hours; see https://archive.ics.uci.edu/ml/datasets/Student+Performance)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "discretised_data = data.copy()\n",
    "\n",
    "data_vals = {col: data[col].unique() for col in data.columns}\n",
    "\n",
    "failures_map = {v: 'no-failure' if v == 0\n",
    "                else 'have-failure' for v in data_vals['failures']}\n",
    "studytime_map = {v: 'short-studytime' if v in [1,2]\n",
    "                 else 'long-studytime' for v in data_vals['studytime']}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we have defined our maps `{old_value: new_value}` we can update each feature, applying the mapping transformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "discretised_data[\"failures\"] = discretised_data[\"failures\"].map(failures_map)\n",
    "discretised_data[\"studytime\"] = discretised_data[\"studytime\"].map(studytime_map)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Discretising Numeric Features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make numeric features categorical, they must first be discretised. CausalNex provides a helper class `causalnex.discretiser.Discretiser`, which supports several discretisation methods. For our data the `fixed` method will be applied, providing static values that define the bucket boundaries. For example, `absences` will be discretised into the buckets < 1, 1 to 9, and >=10. Each bucket will be labelled as an integer from zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from causalnex.discretiser import Discretiser\n",
    "\n",
    "discretised_data[\"absences\"] = Discretiser(method=\"fixed\", \n",
    "                          numeric_split_points=[1, 10]).transform(discretised_data[\"absences\"].values)\n",
    "discretised_data[\"G1\"] = Discretiser(method=\"fixed\", \n",
    "                          numeric_split_points=[10]).transform(discretised_data[\"G1\"].values)\n",
    "discretised_data[\"G2\"] = Discretiser(method=\"fixed\", \n",
    "                          numeric_split_points=[10]).transform(discretised_data[\"G2\"].values)\n",
    "discretised_data[\"G3\"] = Discretiser(method=\"fixed\", \n",
    "                          numeric_split_points=[10]).transform(discretised_data[\"G3\"].values)"
   ]
  },
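  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bucket boundaries described above can be checked with a small stand-alone sketch. It uses `numpy.digitize` as a stand-in for the `fixed` method (an assumption made for illustration, not a statement about how `Discretiser` is implemented), and the sample values are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up absence counts covering each side of the split points.\n",
    "sample_absences = np.array([0, 1, 9, 10, 30])\n",
    "split_points = [1, 10]\n",
    "\n",
    "# Each value is assigned the index of its bucket: < 1, 1 to 9, >= 10.\n",
    "buckets = np.digitize(sample_absences, split_points)\n",
    "print(dict(zip(sample_absences.tolist(), buckets.tolist())))"
   ]
  },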
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Labels for Numeric Features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the discretised categories more readable, we can map the integer bucket labels onto something more meaningful, in the same way that we mapped the categorical feature values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "absences_map = {0: \"No-absence\", 1: \"Low-absence\", 2: \"High-absence\"}\n",
    "\n",
    "G1_map = {0: \"Fail\", 1: \"Pass\"}\n",
    "G2_map = {0: \"Fail\", 1: \"Pass\"}\n",
    "G3_map = {0: \"Fail\", 1: \"Pass\"}\n",
    "\n",
    "discretised_data[\"absences\"] = discretised_data[\"absences\"].map(absences_map)\n",
    "discretised_data[\"G1\"] = discretised_data[\"G1\"].map(G1_map)\n",
    "discretised_data[\"G2\"] = discretised_data[\"G2\"].map(G2_map)\n",
    "discretised_data[\"G3\"] = discretised_data[\"G3\"].map(G3_map)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train / Test Split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As with many other machine learning models, we will use a train/test split to help us validate our findings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split 90% train and 10% test\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "train, test = train_test_split(discretised_data, train_size=0.9, test_size=0.1, random_state=7)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Probability"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the learnt structure model from earlier and the discretised data, we can now fit the probability distribution of the Bayesian Network. The first step is to specify all of the states that each node can take. This can be done either from data or by providing a dictionary of node values. We use the full dataset here to avoid cases where states in our test set do not exist in the training set. For real-world applications, these states may need to be provided using the dictionary method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "bn = bn.fit_node_states(discretised_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fit Conditional Probability Distributions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `fit_cpds` method of `BayesianNetwork` accepts a dataset from which to learn the conditional probability distributions (CPDs) of each node, along with the method used to fit them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "bn = bn.fit_cpds(train, method=\"BayesianEstimator\", bayes_prior=\"K2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we have the CPDs, we can inspect them through the `cpds` property, which is a dictionary of node->cpd."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th>failures</th>\n",
       "      <th colspan=\"8\" halign=\"left\">have-failure</th>\n",
       "      <th colspan=\"8\" halign=\"left\">no-failure</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>higher</th>\n",
       "      <th colspan=\"4\" halign=\"left\">no</th>\n",
       "      <th colspan=\"4\" halign=\"left\">yes</th>\n",
       "      <th colspan=\"4\" halign=\"left\">no</th>\n",
       "      <th colspan=\"4\" halign=\"left\">yes</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>schoolsup</th>\n",
       "      <th colspan=\"2\" halign=\"left\">no</th>\n",
       "      <th colspan=\"2\" halign=\"left\">yes</th>\n",
       "      <th colspan=\"2\" halign=\"left\">no</th>\n",
       "      <th colspan=\"2\" halign=\"left\">yes</th>\n",
       "      <th colspan=\"2\" halign=\"left\">no</th>\n",
       "      <th colspan=\"2\" halign=\"left\">yes</th>\n",
       "      <th colspan=\"2\" halign=\"left\">no</th>\n",
       "      <th colspan=\"2\" halign=\"left\">yes</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>studytime</th>\n",
       "      <th>long-studytime</th>\n",
       "      <th>short-studytime</th>\n",
       "      <th>long-studytime</th>\n",
       "      <th>short-studytime</th>\n",
       "      <th>long-studytime</th>\n",
       "      <th>short-studytime</th>\n",
       "      <th>long-studytime</th>\n",
       "      <th>short-studytime</th>\n",
       "      <th>long-studytime</th>\n",
       "      <th>short-studytime</th>\n",
       "      <th>long-studytime</th>\n",
       "      <th>short-studytime</th>\n",
       "      <th>long-studytime</th>\n",
       "      <th>short-studytime</th>\n",
       "      <th>long-studytime</th>\n",
       "      <th>short-studytime</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>G1</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Fail</th>\n",
       "      <td>0.75</td>\n",
       "      <td>0.806452</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.75</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.612245</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.75</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.612903</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.032967</td>\n",
       "      <td>0.15016</td>\n",
       "      <td>0.111111</td>\n",
       "      <td>0.255814</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Pass</th>\n",
       "      <td>0.25</td>\n",
       "      <td>0.193548</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.387755</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.387097</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.967033</td>\n",
       "      <td>0.84984</td>\n",
       "      <td>0.888889</td>\n",
       "      <td>0.744186</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "failures    have-failure                                                 \\\n",
       "higher                no                                                  \n",
       "schoolsup             no                            yes                   \n",
       "studytime long-studytime short-studytime long-studytime short-studytime   \n",
       "G1                                                                        \n",
       "Fail                0.75        0.806452            0.5            0.75   \n",
       "Pass                0.25        0.193548            0.5            0.25   \n",
       "\n",
       "failures                                                                 \\\n",
       "higher               yes                                                  \n",
       "schoolsup             no                            yes                   \n",
       "studytime long-studytime short-studytime long-studytime short-studytime   \n",
       "G1                                                                        \n",
       "Fail                 0.5        0.612245            0.5            0.75   \n",
       "Pass                 0.5        0.387755            0.5            0.25   \n",
       "\n",
       "failures      no-failure                                                 \\\n",
       "higher                no                                                  \n",
       "schoolsup             no                            yes                   \n",
       "studytime long-studytime short-studytime long-studytime short-studytime   \n",
       "G1                                                                        \n",
       "Fail                 0.5        0.612903            0.5             0.5   \n",
       "Pass                 0.5        0.387097            0.5             0.5   \n",
       "\n",
       "failures                                                                 \n",
       "higher               yes                                                 \n",
       "schoolsup             no                            yes                  \n",
       "studytime long-studytime short-studytime long-studytime short-studytime  \n",
       "G1                                                                       \n",
       "Fail            0.032967         0.15016       0.111111        0.255814  \n",
       "Pass            0.967033         0.84984       0.888889        0.744186  "
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bn.cpds[\"G1\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each CPD is a pandas DataFrame with multi-indexed columns, so the `loc` indexer can be a useful way to interact with it."
   ]
  },
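  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of this kind of lookup, the cell below builds a small stand-in table with the same layout as a CPD (states of the node as rows, parent states as multi-indexed columns) and slices it with `loc`. The table `toy_cpd` and its values are made up for illustration; on the fitted network, the same style of indexing would apply to `bn.cpds[\"G1\"]`, which has one column level per parent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical two-parent CPD: rows are states of the node,\n",
    "# columns are combinations of parent states (a pandas MultiIndex).\n",
    "columns = pd.MultiIndex.from_product(\n",
    "    [[\"have-failure\", \"no-failure\"], [\"long-studytime\", \"short-studytime\"]],\n",
    "    names=[\"failures\", \"studytime\"],\n",
    ")\n",
    "toy_cpd = pd.DataFrame(\n",
    "    [[0.6, 0.7, 0.1, 0.2],\n",
    "     [0.4, 0.3, 0.9, 0.8]],\n",
    "    index=pd.Index([\"Fail\", \"Pass\"], name=\"G1\"),\n",
    "    columns=columns,\n",
    ")\n",
    "\n",
    "# One column of the table: P(G1 | failures=no-failure, studytime=short-studytime).\n",
    "p_given_parents = toy_cpd.loc[:, (\"no-failure\", \"short-studytime\")]\n",
    "\n",
    "# A single entry: P(G1=Pass | failures=no-failure, studytime=short-studytime).\n",
    "p_pass = toy_cpd.loc[\"Pass\", (\"no-failure\", \"short-studytime\")]\n",
    "\n",
    "# All columns for one state of the first parent.\n",
    "no_failure_block = toy_cpd.loc[:, \"no-failure\"]\n",
    "\n",
    "print(p_given_parents)\n",
    "print(p_pass)"
   ]
  },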
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict the State given the Input Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `predict` method of `BayesianNetwork` allows us to make predictions from data using the learnt Bayesian Network. For example, we may want to predict whether a student fails or passes their exam based on the input data. Imagine we have incoming data for a student that looks like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "address                     U\n",
       "famsize                   GT3\n",
       "Pstatus                     T\n",
       "Medu                        3\n",
       "Fedu                        2\n",
       "traveltime                  1\n",
       "studytime     short-studytime\n",
       "failures         have-failure\n",
       "schoolsup                  no\n",
       "famsup                    yes\n",
       "paid                      yes\n",
       "activities                yes\n",
       "nursery                   yes\n",
       "higher                    yes\n",
       "internet                  yes\n",
       "romantic                   no\n",
       "famrel                      5\n",
       "freetime                    5\n",
       "goout                       5\n",
       "Dalc                        2\n",
       "Walc                        4\n",
       "health                      5\n",
       "absences          Low-absence\n",
       "G2                       Fail\n",
       "G3                       Fail\n",
       "Name: 18, dtype: object"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "discretised_data.loc[18, discretised_data.columns != 'G1']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Based on these data, we want to predict whether this student fails their exam. Intuitively, we would expect this student to fail because they spend little time studying and have failed in the past. Let's see how our Bayesian Network performs on this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = bn.predict(discretised_data, \"G1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The prediction is 'Fail'\n"
     ]
    }
   ],
   "source": [
    "print(f\"The prediction is '{predictions.loc[18, 'G1_prediction']}'\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prediction by the Bayesian Network turns out to be a `Fail`. Let's compare this to the ground truth:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The ground truth is 'Fail'\n"
     ]
    }
   ],
   "source": [
    "print(f\"The ground truth is '{discretised_data.loc[18, 'G1']}'\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ground truth is also `Fail`, so the prediction for this student is correct."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Quality"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To evaluate the quality of the model that has been learned, CausalNex supports two main approaches: a classification report, and the Receiver Operating Characteristic (ROC) curve with its Area Under the Curve (AUC). Each is discussed in this section."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification Report"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To obtain a classification report using a BN, we need to provide a test set, and the node we are trying to classify. The report will predict the target node for all rows in the test set, and evaluate how well those predictions are made."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'G1_Fail': {'precision': 0.7777777777777778,\n",
       "  'recall': 0.5833333333333334,\n",
       "  'f1-score': 0.6666666666666666,\n",
       "  'support': 12},\n",
       " 'G1_Pass': {'precision': 0.9107142857142857,\n",
       "  'recall': 0.9622641509433962,\n",
       "  'f1-score': 0.9357798165137615,\n",
       "  'support': 53},\n",
       " 'accuracy': 0.8923076923076924,\n",
       " 'macro avg': {'precision': 0.8442460317460317,\n",
       "  'recall': 0.7727987421383649,\n",
       "  'f1-score': 0.8012232415902141,\n",
       "  'support': 65},\n",
       " 'weighted avg': {'precision': 0.8861721611721611,\n",
       "  'recall': 0.8923076923076924,\n",
       "  'f1-score': 0.8860973888496825,\n",
       "  'support': 65}}"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from causalnex.evaluation import classification_report\n",
    "\n",
    "classification_report(bn, test, \"G1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This report shows that the model we have defined classifies whether a student passes their exam reasonably well, with an overall accuracy of about 0.89.\n",
    "\n",
    "For the predictions where the student fails, the precision is fairly good (about 0.78), but recall is weaker (about 0.58). This implies that we can mostly rely on `Fail` predictions when they are made, but we are likely to miss a sizeable share of the students who actually fail. Perhaps these misses result from something missing in our structure - this could be an interesting area to explore."
   ]
  },
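  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make these metrics concrete, the sketch below recomputes them for the `Fail` class from confusion counts. The counts are not printed by `classification_report`; they are inferred here as the integers consistent with the report above (12 `Fail` rows and 53 `Pass` rows in the test set), so treat them as an assumption."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confusion counts for the Fail class (assumption: inferred from the\n",
    "# reported precision, recall and support rather than read from the model).\n",
    "true_positives = 7    # Fail rows predicted as Fail\n",
    "false_negatives = 5   # Fail rows predicted as Pass\n",
    "false_positives = 2   # Pass rows predicted as Fail\n",
    "true_negatives = 51   # Pass rows predicted as Pass\n",
    "\n",
    "precision = true_positives / (true_positives + false_positives)\n",
    "recall = true_positives / (true_positives + false_negatives)\n",
    "f1 = 2 * precision * recall / (precision + recall)\n",
    "accuracy = (true_positives + true_negatives) / (\n",
    "    true_positives + false_negatives + false_positives + true_negatives\n",
    ")\n",
    "\n",
    "print(f\"precision={precision:.4f} recall={recall:.4f} f1={f1:.4f} accuracy={accuracy:.4f}\")"
   ]
  },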
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ROC / AUC"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Receiver Operating Characteristic (ROC) curve and the Area Under the ROC Curve (AUC) can be obtained using the `roc_auc` function in the `causalnex.evaluation` module. Again, a test set and target node must be provided. The ROC curve is computed by micro-averaging predictions made across all states (classes) of the target node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9181065088757396\n"
     ]
    }
   ],
   "source": [
    "from causalnex.evaluation import roc_auc\n",
    "roc, auc = roc_auc(bn, test, \"G1\")\n",
    "print(auc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The AUC value for our model is high (about 0.92), giving us confidence in its performance."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Querying Marginals"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After iterating over our model structure, CPDs, and validating our model quality, we can query our model under different observations to gain insights."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Baseline Marginals"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To query the model for baseline marginals that reflect the population as a whole, a `query` method can be used. First let's update our model using the complete dataset, since the one we currently have was only built from training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "bn = bn.fit_cpds(discretised_data, method=\"BayesianEstimator\", bayes_prior=\"K2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fitting again may produce warnings, which let us know we are replacing the previously existing CPDs. These can be safely ignored.\n",
    "\n",
    "For inference, we must create a new `InferenceEngine` from our `BayesianNetwork`, which lets us query the model. The `query` method will compute the marginal likelihood of all states for all nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Fail': 0.25260687281677224, 'Pass': 0.7473931271832277}"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from causalnex.inference import InferenceEngine\n",
    "\n",
    "ie = InferenceEngine(bn)\n",
    "marginals = ie.query()\n",
    "marginals[\"G1\"] "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output tells us that `P(G1=Fail) = 0.25` and `P(G1=Pass) = 0.75`. As a quick sanity check, we can compute the proportion of `Fail` rows in our dataset, which should be approximately the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Fail', 157), ('Pass', 492)]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "labels, counts = np.unique(discretised_data[\"G1\"], return_counts=True)\n",
    "list(zip(labels, counts))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The proportion of the students who fail is `157 / (157+492) = 0.242` - which is close to our computed marginal likelihood."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Marginals after Observations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also query the marginal likelihood of states in our network given some observations. These observations can be made anywhere in the network, and their impact will be propagated through to the node of interest.\n",
    "\n",
    "Let's look at the difference in the likelihood of `G1` based on `studytime`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Marginal G1 | Short Studytime {'Fail': 0.2776556433482524, 'Pass': 0.7223443566517477}\n",
      "Marginal G1 | Long Studytime {'Fail': 0.15504850337837614, 'Pass': 0.8449514966216239}\n"
     ]
    }
   ],
   "source": [
    "marginals_short = ie.query({\"studytime\": \"short-studytime\"})\n",
    "marginals_long = ie.query({\"studytime\": \"long-studytime\"})\n",
    "print(\"Marginal G1 | Short Studytime\", marginals_short[\"G1\"])\n",
    "print(\"Marginal G1 | Long Studytime\", marginals_long[\"G1\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Based on our data we can see that students who study longer are more likely to pass their exam."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Do Calculus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CausalNex also supports simple Do-Calculus, allowing us to specify interventions. In this section we will take a look at the supported methods, and what they mean."
   ]
  },
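  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between observing a value and intervening on it can be sketched on a toy network in plain Python, without CausalNex. The three binary variables and all probabilities below are hypothetical: `Z` influences both `X` and `Y`, and `X` influences `Y`. Conditioning on `X=1` also changes our belief about `Z`, whereas `do(X=1)` cuts the edge `Z -> X` and leaves the distribution of `Z` unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical network: Z -> X, Z -> Y, X -> Y (all variables binary).\n",
    "p_z = {0: 0.5, 1: 0.5}                      # P(Z=z)\n",
    "p_x1_given_z = {0: 0.2, 1: 0.8}             # P(X=1 | Z=z)\n",
    "p_y1_given_xz = {(0, 0): 0.1, (0, 1): 0.5,  # P(Y=1 | X=x, Z=z), keyed by (x, z)\n",
    "                 (1, 0): 0.3, (1, 1): 0.9}\n",
    "\n",
    "# Observation: P(Y=1 | X=1) = sum_z P(Y=1 | X=1, z) * P(z | X=1),\n",
    "# where P(z | X=1) follows from Bayes' rule.\n",
    "p_x1 = sum(p_x1_given_z[z] * p_z[z] for z in p_z)\n",
    "p_y1_observe = sum(\n",
    "    p_y1_given_xz[(1, z)] * p_x1_given_z[z] * p_z[z] / p_x1 for z in p_z\n",
    ")\n",
    "\n",
    "# Intervention: P(Y=1 | do(X=1)) = sum_z P(Y=1 | X=1, z) * P(z),\n",
    "# because forcing X removes the influence of Z on X.\n",
    "p_y1_do = sum(p_y1_given_xz[(1, z)] * p_z[z] for z in p_z)\n",
    "\n",
    "print(f\"P(Y=1 | X=1)     = {p_y1_observe:.2f}\")\n",
    "print(f\"P(Y=1 | do(X=1)) = {p_y1_do:.2f}\")"
   ]
  },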
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Updating a Node Distribution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can apply an intervention to any node in our data, updating its distribution using a `do` operator. This can be thought of as asking our model \"What if\" something were different. For example, we could ask what would happen if 100% of students wanted to go on to do higher education."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "distribution before do {'no': 0.10752688172043011, 'yes': 0.8924731182795698}\n",
      "distribution after do {'no': 0.0, 'yes': 0.9999999999999998}\n"
     ]
    }
   ],
   "source": [
    "print(\"distribution before do\", ie.query()[\"higher\"])\n",
    "ie.do_intervention(\"higher\", \n",
    "                   {'yes': 1.0, \n",
    "                    'no': 0.0})\n",
    "print(\"distribution after do\", ie.query()[\"higher\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Resetting a Node Distribution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can reset any interventions that we make by using the `reset_do` method and providing the node that we want to reset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "ie.reset_do(\"higher\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Effect of Do on Marginals"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can again use `query` to examine the effect that an intervention has on our marginal likelihoods. In this case, we can look at how the likelihood of achieving a pass changes if 100% of students wanted to do higher education."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "marginal G1 {'Fail': 0.25260687281677224, 'Pass': 0.7473931271832277}\n",
      "updated marginal G1 {'Fail': 0.20682952942551894, 'Pass': 0.7931704705744809}\n"
     ]
    }
   ],
   "source": [
    "print(\"marginal G1\", ie.query()[\"G1\"])\n",
    "ie.do_intervention(\"higher\", \n",
    "                   {'yes': 1.0, \n",
    "                    'no': 0.0})\n",
    "print(\"updated marginal G1\", ie.query()[\"G1\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this case, we can see that if 100% of students wanted to go on to higher education (as opposed to roughly 89% in our data population), we estimate that the pass rate would increase from 74.7% to 79.3%."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "private-causalnex",
   "language": "python",
   "name": "private-causalnex"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
